{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fc5bde60-1899-461d-8083-3ee04ac7c099",
   "metadata": {},
   "source": [
    "# 模型推理 - 使用 QLoRA 微调后的 ChatGLM-6B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3de57c20-7d62-4605-80d4-e7fb94775303",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['HF_HOME'] = '/mnt/huggingface/hf'\n",
    "os.environ['HF_HUB_CACHE'] = '/mnt/huggingface/hf/hub'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3292b88c-91f0-48d2-91a5-06b0830c7e70",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig\n",
    "\n",
    "# 模型ID或本地路径\n",
    "model_name_or_path = 'THUDM/chatglm3-6b'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9f81454c-24b2-4072-ab05-b25f9b120ae6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 7/7 [00:04<00:00,  1.44it/s]\n"
     ]
    }
   ],
   "source": [
    "_compute_dtype_map = {\n",
    "    'fp32': torch.float32,\n",
    "    'fp16': torch.float16,\n",
    "    'bf16': torch.bfloat16\n",
    "}\n",
    "\n",
    "# QLoRA 量化配置\n",
    "q_config = BitsAndBytesConfig(load_in_4bit=True,\n",
    "                              bnb_4bit_quant_type='nf4',\n",
    "                              bnb_4bit_use_double_quant=True,\n",
    "                              bnb_4bit_compute_dtype=_compute_dtype_map['bf16'])\n",
    "\n",
    "# 加载量化后模型(与微调的 revision 保持一致）\n",
    "base_model = AutoModel.from_pretrained(model_name_or_path,\n",
    "                                      quantization_config=q_config,\n",
    "                                      device_map='auto',\n",
    "                                      trust_remote_code=True,\n",
    "                                      revision='b098244')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d488846f-41bb-4fe6-9f09-0f392f3b39e6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ChatGLMForConditionalGeneration(\n",
       "  (transformer): ChatGLMModel(\n",
       "    (embedding): Embedding(\n",
       "      (word_embeddings): Embedding(65024, 4096)\n",
       "    )\n",
       "    (rotary_pos_emb): RotaryEmbedding()\n",
       "    (encoder): GLMTransformer(\n",
       "      (layers): ModuleList(\n",
       "        (0-27): 28 x GLMBlock(\n",
       "          (input_layernorm): RMSNorm()\n",
       "          (self_attention): SelfAttention(\n",
       "            (query_key_value): Linear4bit(in_features=4096, out_features=4608, bias=True)\n",
       "            (core_attention): CoreAttention(\n",
       "              (attention_dropout): Dropout(p=0.0, inplace=False)\n",
       "            )\n",
       "            (dense): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
       "          )\n",
       "          (post_attention_layernorm): RMSNorm()\n",
       "          (mlp): MLP(\n",
       "            (dense_h_to_4h): Linear4bit(in_features=4096, out_features=27392, bias=False)\n",
       "            (dense_4h_to_h): Linear4bit(in_features=13696, out_features=4096, bias=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (final_layernorm): RMSNorm()\n",
       "    )\n",
       "    (output_layer): Linear(in_features=4096, out_features=65024, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "base_model.requires_grad_(False)\n",
    "base_model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7e4270e2-c827-450e-bf27-7cb43a97f8f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,\n",
    "                                          trust_remote_code=True,\n",
    "                                          revision='b098244')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63408b60-876e-4eda-b501-90f842cca002",
   "metadata": {},
   "source": [
    "## 使用原始 ChatGLM3-6B 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6ef405cf-7d77-41a6-a07b-c6c768ee30cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_text = \"解释下乾卦是什么？\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "566ed80e-828b-4105-b6e6-49de8905c991",
   "metadata": {},
   "outputs": [],
   "source": [
    "response, history = base_model.chat(tokenizer, query=input_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6cee217e-f276-4c2f-94e7-69afb6d541a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "乾卦是八卦之一，也是八宫图之一，它的卦象是由三个阳爻夹一个阴爻构成，象征着天、云和雷等自然现象。乾卦的含义有：\n",
      "\n",
      "1. 阳刚之气：乾卦代表着阳刚之气，象征着天、云和雷等自然现象，以及坚毅、果敢、积极向上的品质。\n",
      "2. 力量和权威：乾卦也象征着力量和权威，表示有 strength and jurisdiction.\n",
      "3. 创造力和领导能力：乾卦还代表着创造力和领导能力，表示有 ability to create and lead.\n",
      "4. 进展和进步：乾卦也象征着进展和进步，表示 things will improve and progress will be made.\n",
      "\n",
      "在五行中，乾卦属于木，代表着生长、茂盛和勃发。在八宫图中，乾卦位于北方，与事业、努力、积极向上等有关。\n"
     ]
    }
   ],
   "source": [
    "print(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3db3245d-037d-4fe5-ac0d-cc5e82742399",
   "metadata": {},
   "source": [
    "#### 询问一个64卦相关问题（应该不在 ChatGLM3-6B 预训练数据中）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "bbe1395f-39c2-4759-ae81-90ef3bcfae47",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "讼卦是八卦之一，它的卦象是由两个阴爻夹一个阳爻构成，象征着水、太平和和谐等自然现象。讼卦的含义有：\n",
      "\n",
      "1. 和谐与和平：讼卦代表着和谐与和平，表示通过沟通和协商，可以使矛盾得到化解，达到和谐和平的状态。\n",
      "2. 法律与规则：讼卦也象征着法律和规则，表示在解决问题时应该遵守法律法规，以及制定合理的规则和标准。\n",
      "3. 挫折和困难：讼卦还代表着挫折和困难，表示在实现目标的过程中可能会遇到困难和挫折，需要有毅力和耐心去克服。\n",
      "4. 变化和适应：讼卦也象征着变化和适应，表示面对不断变化的环境和情况，需要灵活变通，适应新的变化。\n",
      "\n",
      "在五行中，讼卦属于水，代表着温柔、柔软和流动。在八宫图中，讼卦位于西方，与法律、规则、变化等有关。\n"
     ]
    }
   ],
   "source": [
    "response, history = base_model.chat(tokenizer, query=\"周易中的讼卦是什么？\", history=history)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "342b3659-d644-4232-8af1-f092e733bf40",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "6d23e720-dee1-4b43-a298-0cbe1d8ad11d",
   "metadata": {},
   "source": [
    "## 使用微调后的 ChatGLM3-6B"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bcfc5a2-41ed-405c-a31c-dca4fbb67425",
   "metadata": {},
   "source": [
    "### 加载 QLoRA Adapter(Epoch=3, automade-dataset(fixed)) - 请根据训练时间戳修改 timestamp "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9c767c67-42aa-459c-a096-e226226c359b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "epochs = 3\n",
    "# timestamp = \"20240118_164514\"\n",
    "timestamp = \"20240419_104829\"\n",
    "\n",
    "peft_model_path = f\"models/{model_name_or_path}-epoch{epochs}-{timestamp}\"\n",
    "\n",
    "config = PeftConfig.from_pretrained(peft_model_path)\n",
    "qlora_model = PeftModel.from_pretrained(base_model, peft_model_path)\n",
    "training_tag=f\"ChatGLM3-6B(Epoch=3, automade-dataset(fixed))-{timestamp}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "24a5d22b-2c94-4dcf-8135-18d78f98755f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_chatglm_results(query, base_model, qlora_model, training_tag):\n",
    "    base_response, base_history = base_model.chat(tokenizer, query)\n",
    "\n",
    "    inputs = tokenizer(query, return_tensors=\"pt\").to(0)\n",
    "    ft_out = qlora_model.generate(**inputs, max_new_tokens=512)\n",
    "    ft_response = tokenizer.decode(ft_out[0], skip_special_tokens=True)\n",
    "    \n",
    "    print(f\"问题：{query}\\n\\n原始输出：\\n{base_response}\\n\\n\\n微调后（{training_tag}）：\\n{ft_response}\")\n",
    "    return base_response, ft_response"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "062cd62e-69f9-4605-8c83-e468f71ef3d3",
   "metadata": {},
   "source": [
    "### 微调前后效果对比"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7db16cd5-0bb5-44ab-b861-d9ca6a4970c2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "问题：解释下乾卦是什么？\n",
      "\n",
      "原始输出：\n",
      "乾卦，又称天卦，是易经中的第一卦，象征着天和龙（德才的君子）及其它的君子。乾卦的卦辞价为：天（德）刚健，运行不已，但由于其刚健，有时会与阳刚之卦--坤卦相互冲击。\n",
      "\n",
      "乾卦的卦象为：天雷声势刚健，加上了君子们的象，象征着天和德行结合，君子以天为德，刚健运行。这一卦象象征着天高且运行不已，具有天德，寓意着君子们应当具有德行刚健、美恶不辨的品德。\n",
      "\n",
      "乾卦，犹如天雷，预示着天高且运行不止，象征着德行刚健。君子观此卦象，从而以天为榜样，追求德行刚健。同时，天高且运行不止的卦象也警示着君子们需要注意行动的规范，不使德行受到损害。\n",
      "\n",
      "在事业、经商、求名、婚恋和决策等方面，乾卦提醒人们要具有德行刚健、积极向上，坚持正道，勇敢面对困难，不怕挑战，从而取得成功。同时，乾卦也提醒人们要追求德行刚健，不断追求进步，勇往直前，从而在事业上取得成功。\n",
      "\n",
      "\n",
      "微调后（ChatGLM3-6B(Epoch=3, automade-dataset(fixed))-20240419_104829）：\n",
      "[gMASK]sop 解释下乾卦是什么？ 在周易中，乾卦是六十四卦之首，象征着天和龙（德才的君子），代表着纯粹的阳和健，意味着兴盛和强健。乾卦的占卜结果是大吉大利，代表着吉利和贞正。《象辞》中说，“天道刚健，运行不已”，君子观此卦象，从而以天为法，自强不息。这一卦象教导人遵守天道的德行，表明事业如日中天，但也需警惕盛极必衰的道理。在事业、经商、求名、婚恋和决策等方面，都需要保持积极谨慎的态度，坚持正直、公允的实质，修养德行，积累知识，自强不息，方能克服困难，取得成功。\n"
     ]
    }
   ],
   "source": [
    "base_response, ft_response = compare_chatglm_results(\"解释下乾卦是什么？\", base_model, qlora_model, training_tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7aa074bd-c819-4533-a10f-f3184dc9549a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "问题：周易中的讼卦是什么\n",
      "\n",
      "原始输出：\n",
      "{'name': '在周易中，讼卦是一个寓意深远的卦象。它由上卦乾（天）和下卦坎（水）组成，象征着天与水相违行，即天高水深，事理乖舛。讼卦代表着争论与争讼之象，虽有利可图，但需要警惕戒惧。初期事情会有中间的顺利，但后来会出现凶险。值得注意的是，占得此卦有利于会见贵族王公，但不利于涉水渡河。', 'content': '\\n讼卦在象辞中呈现出天高水深的局面，预示着事事都不宜强求，需要慎之又慎，没有从容处世的态度，很容易陷入争论纠纷。因此君子观此卦象，以杜绝争讼为意，从而在谋事之初必须慎之又慎。整体而言，讼卦提示人们应当谨慎处理事务，并避免与他人在争执中陷入对抗。在事业、经商、求名、婚恋和决策等方面，皆需要小心提防，退让让人，化解争端，方能避免不测之灾，最终得到吉祥和利益。'}\n",
      "\n",
      "\n",
      "微调后（ChatGLM3-6B(Epoch=3, automade-dataset(fixed))-20240419_104829）：\n",
      "[gMASK]sop 周易中的讼卦是什么样子的卦象, 在周易中，讼卦是一个寓意深远的卦象。它由上卦乾（天）和下卦坎（水）组成，象征着天与水相违行，即天高水深，事理乖舛。讼卦代表着争论与争讼之象，虽有利可图，但需要警惕戒惧。初期事情会有中间的顺利，但后来会出现凶险。值得注意的是，占得此卦有利于会见贵族王公，但不利于涉水渡河。\n",
      "\n",
      "讼卦在象辞中呈现出天高水深的局面，预示着事事都不宜强求，需要慎之又慎，没有从容处世的态度，很容易陷入争论纠纷。因此君子观此卦象，以杜绝争讼为意，从而在谋事之初必须慎之又慎。整体而言，讼卦提示人们应当谨慎处理事务，并避免与他人在争执中陷入对抗。在事业、经商、求名、婚恋和决策等方面，皆需要小心提防，退让让人，化解争端，方能避免不测之灾，最终得到吉祥和利益。\n"
     ]
    }
   ],
   "source": [
    "base_response, ft_response = compare_chatglm_results(\"周易中的讼卦是什么\", base_model, qlora_model, training_tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d5a31554-40f1-4e6e-8240-f207c4a61b42",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "问题：师卦是什么？\n",
      "\n",
      "原始输出：\n",
      "在周易中，师卦是一个极具深意的卦象，它由两个异卦相叠组成：下卦坎（水）和上卦坤（地）。这一卦象代表“师”，即军队，寓意着兵力和农力的结合。在这里，坎卦象征着水和险难，而坤卦象征着地和顺从，暗示着通过将军事力量安置于民间，可以在必要时顺利调动。\n",
      "\n",
      "师卦的核心哲学是：虽然兵力代表着危险和战争，但其使用应当是圣人不得已而为之的最后手段。在正确的情况下，军事力量可以顺应形势，将危险转化为吉祥。因此，在军事策略上，此卦象征着出征将会顺利，无灾祸。\n",
      "\n",
      "师卦紧随讼卦（争讼卦），在《序卦》中解释为“讼必有众起，故受之以师”。这意味着争端激化至众多人群的参与，形成了类似军队的集体力量。在卜筮中，师卦出现，意味着困难重重，需要包容他人，艰苦努力，摒除一切困难。此外，在运势、事业、经商、求名、婚恋等方面也有着具体的解释和建议。\n",
      "\n",
      "\n",
      "微调后（ChatGLM3-6B(Epoch=3, automade-dataset(fixed))-20240419_104829）：\n",
      "[gMASK]sop 师卦是什么？ 在周易中，师卦是一个极具深意的卦象，它由两个异卦相叠组成：下卦坎（水）和上卦坤（地）。这一卦象代表“师”，即军队，寓意着兵力和农力的结合。在这里，坎卦象征着水和险难，而坤卦象征着地和顺从，暗示着通过将军事力量安置于民间，可以在必要时顺利调动。\n",
      "\n",
      "师卦的核心哲学是：虽然兵力代表着危险和战争，但其使用应当是圣人不得已而为之的最后手段。在正确的情况下，军事力量可以顺应形势，将危险转化为吉祥。因此，在军事策略上，此卦象征着出征将会顺利，无灾祸。\n",
      "\n",
      "师卦紧随讼卦（争讼卦），在《序卦》中解释为“讼必有众起，故受之以师”。这意味着争端激化至众多人群的参与，形成了类似军队的集体力量。在卜筮中，师卦出现，意味着困难重重，需要包容他人，艰苦努力，摒除一切困难。此外，在运势、事业、经商、求名、婚恋等方面也有着具体的解释和建议。\n"
     ]
    }
   ],
   "source": [
    "base_response, ft_response = compare_chatglm_results(\"师卦是什么？\", base_model, qlora_model, training_tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abae8a8e-00bb-4801-931a-c942206f0e2a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "7d48183f-f1dc-4171-b217-e269a5b9c1b9",
   "metadata": {},
   "source": [
    "## 其他模型（错误数据或训练参数）\n",
    "\n",
    "#### 加载 QLoRA Adapter(Epoch=3, automade-dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "46a0e881-a4f3-43b2-8a61-0ec543a538a7",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Can't find 'adapter_config.json' at 'models/THUDM/chatglm3-6b-epoch3'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mHFValidationError\u001b[0m                         Traceback (most recent call last)",
      "File \u001b[0;32m~/miniconda3/envs/langchain/lib/python3.10/site-packages/peft/config.py:108\u001b[0m, in \u001b[0;36mPeftConfigMixin.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, subfolder, **kwargs)\u001b[0m\n\u001b[1;32m    107\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 108\u001b[0m     config_file \u001b[38;5;241m=\u001b[39m \u001b[43mhf_hub_download\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    109\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mCONFIG_NAME\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhf_hub_download_kwargs\u001b[49m\n\u001b[1;32m    110\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    111\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n",
      "File \u001b[0;32m~/miniconda3/envs/langchain/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:111\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    110\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m arg_name \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrepo_id\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfrom_id\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mto_id\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[0;32m--> 111\u001b[0m     \u001b[43mvalidate_repo_id\u001b[49m\u001b[43m(\u001b[49m\u001b[43marg_value\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m arg_name \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtoken\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m arg_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/miniconda3/envs/langchain/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:159\u001b[0m, in \u001b[0;36mvalidate_repo_id\u001b[0;34m(repo_id)\u001b[0m\n\u001b[1;32m    158\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m repo_id\u001b[38;5;241m.\u001b[39mcount(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[0;32m--> 159\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m HFValidationError(\n\u001b[1;32m    160\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRepo id must be in the form \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrepo_name\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m or \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnamespace/repo_name\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m:\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    161\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrepo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m. Use `repo_type` argument if needed.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    162\u001b[0m     )\n\u001b[1;32m    164\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m REPO_ID_REGEX\u001b[38;5;241m.\u001b[39mmatch(repo_id):\n",
      "\u001b[0;31mHFValidationError\u001b[0m: Repo id must be in the form 'repo_name' or 'namespace/repo_name': 'models/THUDM/chatglm3-6b-epoch3'. Use `repo_type` argument if needed.",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[15], line 6\u001b[0m\n\u001b[1;32m      3\u001b[0m epochs \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m3\u001b[39m\n\u001b[1;32m      4\u001b[0m peft_model_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name_or_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m-epoch\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepochs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 6\u001b[0m config \u001b[38;5;241m=\u001b[39m \u001b[43mPeftConfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpeft_model_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m      7\u001b[0m qlora_model_e3 \u001b[38;5;241m=\u001b[39m PeftModel\u001b[38;5;241m.\u001b[39mfrom_pretrained(base_model, peft_model_path)\n\u001b[1;32m      8\u001b[0m training_tag \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mChatGLM3-6B(Epoch=3, automade-dataset)\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
      "File \u001b[0;32m~/miniconda3/envs/langchain/lib/python3.10/site-packages/peft/config.py:112\u001b[0m, in \u001b[0;36mPeftConfigMixin.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, subfolder, **kwargs)\u001b[0m\n\u001b[1;32m    108\u001b[0m         config_file \u001b[38;5;241m=\u001b[39m hf_hub_download(\n\u001b[1;32m    109\u001b[0m             pretrained_model_name_or_path, CONFIG_NAME, subfolder\u001b[38;5;241m=\u001b[39msubfolder, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mhf_hub_download_kwargs\n\u001b[1;32m    110\u001b[0m         )\n\u001b[1;32m    111\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[0;32m--> 112\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCan\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt find \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mCONFIG_NAME\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m at \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpretrained_model_name_or_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    114\u001b[0m loaded_attributes \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39mfrom_json_file(config_file)\n\u001b[1;32m    116\u001b[0m \u001b[38;5;66;03m# TODO: this hack is needed to fix the following issue (on commit 702f937):\u001b[39;00m\n\u001b[1;32m    117\u001b[0m \u001b[38;5;66;03m# if someone saves a default config and loads it back with `PeftConfig` class it yields to\u001b[39;00m\n\u001b[1;32m    118\u001b[0m \u001b[38;5;66;03m# not loading the correct config class.\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    128\u001b[0m \u001b[38;5;66;03m# print(peft_config)\u001b[39;00m\n\u001b[1;32m    129\u001b[0m \u001b[38;5;66;03m# >>> PeftConfig(peft_type='ADALORA', auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=None, inference_mode=False)\u001b[39;00m\n",
      "\u001b[0;31mValueError\u001b[0m: Can't find 'adapter_config.json' at 'models/THUDM/chatglm3-6b-epoch3'"
     ]
    }
   ],
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "epochs = 3\n",
    "peft_model_path = f\"models/{model_name_or_path}-epoch{epochs}\"\n",
    "\n",
    "config = PeftConfig.from_pretrained(peft_model_path)\n",
    "qlora_model_e3 = PeftModel.from_pretrained(base_model, peft_model_path)\n",
    "training_tag = f\"ChatGLM3-6B(Epoch=3, automade-dataset)\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f53196e-f523-4105-b04a-9ddab349cce1",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_response, ft_response = compare_chatglm_results(\"解释下乾卦是什么？\", base_model, qlora_model_e3, training_tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "046306ad-6afe-4ec9-ae55-3df04f61d8f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_response, ft_response = compare_chatglm_results(\"地水师卦是什么？\", base_model, qlora_model_e3, training_tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ab3c310-8cc8-428a-91fa-964b7a58df43",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_response, ft_response = compare_chatglm_results(\"周易中的讼卦是什么\", base_model, qlora_model_e3, training_tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cfffcc5-afa6-45c1-985a-a3eb86a0d1c8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "8169237c-55d3-4d91-9f6b-8dbe635f1844",
   "metadata": {},
   "source": [
    "#### 加载 QLoRA Adapter(Epoch=50, Overfit, handmade-dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72e6cc4f-c030-4107-b07a-6ef44f66a4b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "epochs = 50\n",
    "peft_model_path = f\"models/{model_name_or_path}-epoch{epochs}\"\n",
    "\n",
    "config = PeftConfig.from_pretrained(peft_model_path)\n",
    "qlora_model_e50_handmade = PeftModel.from_pretrained(base_model, peft_model_path)\n",
    "training_tag = f\"ChatGLM3-6B(Epoch=50, handmade-dataset)\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d63b187-37be-4721-8959-098d0437c41d",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_response, ft_response = compare_chatglm_results(\"解释下乾卦是什么？\", base_model, qlora_model_e50_handmade, training_tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be5da80e-d1de-467f-a3bb-508d5a77a46d",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_response, ft_response = compare_chatglm_results(\"地水师卦\", base_model, qlora_model_e50_handmade, training_tag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04f0eb9a-5075-4588-914a-2538bea801aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_response, ft_response = compare_chatglm_results(\"天水讼卦\", base_model, qlora_model_e50_handmade, training_tag)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
